'''
Generate the plots of Fig. 2 and save them to Tip_frac.pdf.
'''
import numpy as np
import matplotlib.pyplot as plt
from functions import updated_signal


def get_dist(filename, dx, xmax=10000.):
	'''
	Load values from a text file and return their smoothed, normalised histogram.

	filename: text file readable by np.loadtxt; all entries are pooled together.
	dx: histogram bin width.
	xmax: upper end of the binned range [0, xmax] (default 10000); values
	      outside that range are not counted by np.histogram.

	Returns (centers, density): the bin centres and a density smoothed with a
	centred 3-bin moving average (the two end bins are left unsmoothed) and
	normalised so that sum(density) * dx == 1.
	'''
	data = np.loadtxt(filename)
	counts, edges = np.histogram(np.ravel(data), bins=np.arange(0, xmax + 1, dx))
	# Smooth in floating point and from an untouched copy: averaging the integer
	# counts in place would truncate each average and feed already-smoothed
	# values into the next bin.
	counts = counts.astype(float)
	smooth = counts.copy()
	smooth[1:-1] = (counts[:-2] + counts[1:-1] + counts[2:]) / 3.
	centers = edges[:-1] + dx / 2
	return centers, smooth / (np.sum(smooth) * dx)


# Run configuration shared by every panel of the figure.
n = 30  # the lattice has n x n cells
hexagonal = True
nsim = 50  # number of consecutive n-row blocks read from each data file
vegf = np.array([0., 10., 50., 100., 150., 200., 220., 240., 260., 280., 300., 500., 1000., 2000., 3000., 5000., 10000.])

# Plot coordinates for drawing the lattice as offset rows of hexagons:
# y depends only on the second index (sqrt(3)/2 spacing), while x runs along
# the first index and is shifted by 0.5 for even second-index columns.
y = np.tile(np.sqrt(3) * np.arange(1, n + 1, 1) / 2., (n, 1))
x = np.add.outer(np.arange(1, n + 1, 1), np.where(np.arange(n) % 2 != 0, 0., 0.5))


def _neighbour_sum(a):
	'''
	Return, for every cell of the 2-D array a, the sum of its six lattice
	neighbours with periodic wrap-around along both axes.

	Neighbours of cell (r, c): (r-1, c), (r+1, c), (r, c-1), (r, c+1), plus
	(r-1, c+1), (r+1, c+1) when r is even, or (r-1, c-1), (r+1, c-1) when r is odd.
	'''
	vertical = np.roll(a, -1, axis=0) + np.roll(a, 1, axis=0)  # a[r+1, c] + a[r-1, c]
	horizontal = np.roll(a, -1, axis=1) + np.roll(a, 1, axis=1)  # a[r, c+1] + a[r, c-1]
	even_row = (np.arange(a.shape[0]) % 2 == 0)[:, np.newaxis]
	# diagonal pair: the vertical neighbours shifted one column right (even rows) or left (odd rows)
	diagonal = np.where(even_row, np.roll(vertical, -1, axis=1), np.roll(vertical, 1, axis=1))
	return vertical + horizontal + diagonal


def dis_score(x, hexagonal=False):
	'''
	Score the disorder of a lattice pattern of Delta levels.

	Each cell's external signal is the mean level of its six neighbours on a
	periodic lattice (same-row left/right, the cells directly above/below, and
	two diagonal cells whose column offset is +1 on even rows and -1 on odd
	rows). A cell whose own level exceeds that signal is counted as a tip.

	x: 2-D matrix with delta levels (the figure script passes n x n blocks).
	hexagonal: accepted for call compatibility; only the six-neighbour stencil
	           described above is implemented.

	Returns (ntip, err): ntip is the number of tip cells; err is the number of
	tip-to-tip neighbour links divided by two (for an even number of rows the
	neighbour relation is symmetric, so this equals the number of adjacent
	tip pairs).
	'''
	x = np.asarray(x, dtype=float)
	# Compute the signal directly with wrap-around rather than on a padded copy:
	# tips on the lattice edge need the correct scores of their wrapped-around
	# neighbours for the error count below.
	score = x - _neighbour_sum(x) / 6.
	tip = score > 0.
	ntip = int(np.count_nonzero(tip))
	# For every tip, count how many of its six neighbours are tips as well.
	err = int(np.sum(_neighbour_sum(tip.astype(int))[tip]))
	return ntip, err / 2.



# ---- Build Fig. 2 and save it to Tip_frac.pdf ----
fig = plt.figure(figsize=(10,12))

# Panel (middle row, full width): smoothed Delta distributions for selected VEGF_EXT levels.
ax1 = plt.subplot2grid( (3,2), (1,0), rowspan=1, colspan=2 )
ax1.set_yscale('log')
dx = 10.
vegf_plot = np.array([10., 100., 500., 1000.,10000.])
colors = ['k', 'r', 'b', 'g', 'm']
for v, color in zip(vegf_plot, colors):
	file = 'sim_data/delta_special_' + str(int(v)) + '_0.txt'
	delta, dist = get_dist(file, dx)
	plt.plot(delta, dist, '-', color=color, label='$VEGF_{EXT}$ = ' + str(int(v)))
plt.xlim([0,7000])
plt.ylim([10**(-5), 0.1])
plt.xlabel('Delta (molecules)')
plt.ylabel('Probability density')
plt.legend(loc='upper right', ncol=5)

# The pattern panel and the Notch/Delta map both use the VEGF_EXT = 240 run: load it once.
delta_240 = np.loadtxt('sim_data/delta_special_240_0.txt')

# Panel (top left): Delta levels of the first n rows of that file drawn on the hexagonal lattice.
ax2 = plt.subplot2grid( (3,2), (0,0), rowspan=1, colspan=1 )
pattern = delta_240[0:n]
im = ax2.scatter(x, y, c=pattern, s=80000/(n*n), cmap='Reds', vmin=0, vmax=4000 ,linewidths=0.75, edgecolors='k', marker=(6, 0, 0))
cbar = fig.colorbar(im, ax=ax2 , orientation='vertical')
cbar.set_label('Delta (Molecules)')
plt.xlim([1 , n +0.5])
plt.ylim([np.sqrt(3 ) /2 ,np.sqrt(3 ) * n /2])
plt.xlabel('Cells')
plt.ylabel('Cells')
plt.xticks([])
plt.yticks([])

# Panel (top right): 2-D histogram of (Notch, Delta) on log-spaced bins.
notch = np.loadtxt('sim_data/notch_special_240_0.txt')
edges = np.logspace(0, 4, num=50)
# Geometric bin centres, so each contour sample sits in the middle of its log-spaced bin.
centers = np.sqrt(edges[:-1] * edges[1:])
counts, _, _ = np.histogram2d(np.ravel(notch), np.ravel(delta_240), bins=[edges, edges])
log_counts = np.log10(counts + 1)
# 5-point (cross) average of the interior bins, taken from the unsmoothed map so
# that already-smoothed neighbours are not reused; edge bins are left unchanged.
smooth = log_counts.copy()
neighbours = log_counts[:-2, 1:-1] + log_counts[2:, 1:-1] + log_counts[1:-1, :-2] + log_counts[1:-1, 2:]
smooth[1:-1, 1:-1] = (log_counts[1:-1, 1:-1] + neighbours) / 5.

ax = plt.subplot2grid( (3,2), (0,1), rowspan=1, colspan=1 )
ax.set_xscale('log')
ax.set_yscale('log')
# transpose so rows follow Delta (y axis) and columns follow Notch (x axis)
pt = plt.contourf(centers, centers, np.transpose(smooth), 11, cmap='Reds' )
cbar= plt.colorbar(pt, ax=ax)
# NOTE(review): the map holds log10(count + 1) of raw counts, not of a normalised p(N,D).
cbar.set_label('$log(p(N,D)+1)$')
plt.xlabel('Notch (molecules)')
plt.ylabel('Delta (molecules)')
plt.xlim([300, 10000])
plt.ylim([20, 10000])

# Per-VEGF_EXT statistics over the nsim lattices stored in each file; each file is read once
# and feeds both the disorder index and the tip/stalk fractions.
dis = np.zeros(vegf.size)
std_dis = np.zeros(vegf.size)
tip_frac = np.zeros(vegf.size)
tip_err = np.zeros(vegf.size)
stalk_frac = np.zeros(vegf.size)
stalk_err = np.zeros(vegf.size)
for p in range(vegf.size):
	file = 'sim_data/delta_special_' + str(int(vegf[p])) + '_0.txt'
	data = np.loadtxt(file)
	ntip = np.zeros(nsim)
	mist = np.zeros(nsim)
	tip_num = np.zeros(nsim)
	for i in range(nsim):
		delta = data[i * n:(i + 1) * n]
		ntip[i], mist[i] = dis_score(delta, hexagonal=True)
		score = delta - updated_signal(delta, hexagonal=True)
		tip_num[i] = np.count_nonzero(score > 0.) / float(n * n)
	stalk_num = 1. - tip_num
	dis[p] = np.mean(mist)/np.mean(ntip)
	std_dis[p] = np.std(mist)/np.mean(ntip)
	tip_frac[p], tip_err[p] = np.mean(tip_num), np.std(tip_num)
	stalk_frac[p], stalk_err[p] = np.mean(stalk_num), np.std(stalk_num)

# Panel (bottom right): disorder index vs VEGF_EXT (VEGF_EXT = 0 skipped on the log axis).
ax4 = plt.subplot2grid( (3,2), (2,1), rowspan=1, colspan=1 )
ax4.set_xscale('log')
plt.errorbar( vegf[1:], dis[1:], yerr=std_dis[1:], fmt='ko-', ecolor='k' )
plt.xlim([10, 10000])
plt.ylim([0, 0.8])
plt.xlabel('$VEGF_{EXT}$')
plt.ylabel('Disorder index')

# Panel (bottom left): model tip-cell fraction vs VEGF_EXT, with experimental reference bands.
ax5 = plt.subplot2grid( (3,2), (2,0), rowspan=1, colspan=1 )
ax5.set_xscale('log')
plt.errorbar( vegf[1:], tip_frac[1:], yerr=tip_err[1:], fmt='ko-', ecolor='k', label='Model' )
# experimental tip-cell fractions (two values per VEGF_EXT condition)
tip_V10, tip_V100 = np.array([0.261, 0.145]), np.array([0.34, 0.305])
mean10, std10 = np.mean(tip_V10), np.std(tip_V10)
mean100, std100 = np.mean(tip_V100), np.std(tip_V100)
plt.plot( [10,10000], [mean10, mean10], 'r--', label='Parental vessel - $VEGF_{EXT}$=10ng/ml' )
plt.fill_between( [10,10000], [mean10+std10, mean10+std10], [mean10-std10, mean10-std10], color='r', alpha=0.4 )
plt.plot( [10,10000], [mean100, mean100], 'g--', label='Parental vessel - $VEGF_{EXT}$=100ng/ml' )
plt.fill_between( [10,10000], [mean100+std100, mean100+std100], [mean100-std100, mean100-std100], color='g', alpha=0.4 )
plt.ylim([0., 0.45])
plt.xlim([10, 10000])
plt.xlabel('$VEGF_{EXT}$')
plt.ylabel('Tip cell fraction')
plt.legend(loc='lower right')

plt.tight_layout()
plt.savefig('Tip_frac.pdf', format='pdf', dpi=300)



